[Attention] Optimize FlashInfer MetadataBuilder Build call #21137
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request is a significant and well-executed refactoring of the attention backend infrastructure. The primary goal of decoupling the metadata builders from the model runner has been achieved, which improves modularity and maintainability. The optimization for FlashInfer by preparing metadata on the CPU is a key improvement and has been implemented correctly.
The introduction of CommonAttentionMetadata as a unified data structure is a solid design choice that simplifies the data flow to the attention backends. The refactoring of the speculative decoding logic, particularly in vllm/v1/spec_decode/eagle.py, to remove the Triton kernel in favor of a more readable PyTorch/NumPy implementation is a notable improvement.
The addition of a comprehensive test suite in tests/v1/attention/test_attention_backends.py is excellent. It provides strong validation for the correctness of this large-scale refactoring by comparing various backends against a reference implementation under realistic conditions.
Overall, the changes are of high quality and represent a positive step forward for the codebase. I have not identified any issues of high or critical severity.
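For orientation, the sketch below illustrates the kind of unified metadata structure the review refers to: device tensors for the kernels plus host-side mirrors that metadata builders can read without a D2H copy. The field names and types are assumptions for illustration, not the exact vLLM definition.

```python
# Illustrative sketch only -- not the actual vLLM definition. Field names and
# types are assumptions meant to show the "unified CPU + GPU metadata" idea.
from dataclasses import dataclass

import torch


@dataclass
class CommonAttentionMetadataSketch:
    # Device-side tensors consumed directly by the attention kernels.
    query_start_loc: torch.Tensor      # [num_reqs + 1], int32, on the GPU
    seq_lens: torch.Tensor             # [num_reqs], int32, on the GPU
    # Host-side mirrors kept by the model runner; metadata builders (e.g. the
    # FlashInfer planner) can read these without issuing a D2H copy.
    query_start_loc_cpu: torch.Tensor  # [num_reqs + 1], int32, pinned CPU
    seq_lens_cpu: torch.Tensor         # [num_reqs], int32, pinned CPU
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int
```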
Optimize V1 FlashInfer backend to use CPU host buffers:
- Replace GPU-to-CPU transfers with direct CPU tensor construction
- Build planning tensors from existing CommonAttentionMetadata CPU buffers
- Reduce from 6x to 1x .cpu() calls during FlashInfer planning
- Fix test mocks to handle correct argument count
- Maintain compatibility with GPUModelRunner and FlashInfer V1 backend

Follow-up commits: "dont transfer block table", "optimize".

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Force-pushed from 87ccacf to 8af5f3b.
BTW why don't we use Numpy instead of PyTorch CPU tensors? Except for some edge cases, Numpy is usually faster in my experience.
Could we still pass the device tensors to FlashInfer's plan() rather than host tensors? We might want to support full CUDA graph capture for FlashInfer in the future (a rough implementation exists in #20059), which requires managing device-side persistent buffers that can be reused across different decode wrappers; each decode wrapper corresponds to a runtime shape that needs to be captured. Also, if we pass host tensors to the wrapper, it seems that H2D transfers still exist. If I remember correctly, Sglang's implementation overrides the plan functions so that they still pass host-side persistent buffers, and also explicitly avoids certain D2H transfers. Hope it's helpful! @LucasWilkinson
I've found going to and from numpy (i.e. …
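As a rough way to sanity-check this trade-off, here is a micro-benchmark sketch (not from the PR) that times a small host-side op in NumPy, in PyTorch CPU tensors, and via a torch ↔ NumPy round-trip; the size and the cumsum op are arbitrary assumptions chosen for illustration.

```python
# Rough micro-benchmark sketch (not from the PR): compares a small host-side op
# done in NumPy, in PyTorch CPU tensors, and via a torch <-> NumPy round-trip.
# from_numpy()/.numpy() are zero-copy, but the per-call overhead still adds up
# when a metadata builder runs every scheduler step.
import timeit

import numpy as np
import torch

t = torch.arange(8192, dtype=torch.int32)
a = t.numpy()

print("np.cumsum              :", timeit.timeit(lambda: np.cumsum(a), number=10_000))
print("torch.cumsum (CPU)     :", timeit.timeit(lambda: torch.cumsum(t, 0), number=10_000))
print("torch -> numpy -> torch:",
      timeit.timeit(lambda: torch.from_numpy(np.cumsum(t.numpy())), number=10_000))
```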
If you look in FlashInfer's …
Yes; however, H2D transfers are preferred over D2H because they can be done in a non-blocking fashion and do not force synchronization with the GPU. For the build call we are trying to optimize the CPU overhead, so the fire-and-forget nature of H2D transfers is better than depending on a D2H transfer.
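A small sketch of that asymmetry, assuming a CUDA device is available (illustration only, not code from this PR):

```python
# Illustration of the D2H vs H2D asymmetry discussed above (not PR code).
import torch

device = torch.device("cuda")  # assumes a CUDA device is available

# D2H: pulling a GPU result back to the host blocks the CPU until the GPU has
# finished producing it, so each .cpu() call is a potential sync point.
gpu_tensor = torch.arange(4096, device=device) * 2
host_copy = gpu_tensor.cpu()  # implicit device synchronization

# H2D: from a pinned host buffer the copy can be issued non-blocking and the
# CPU immediately moves on ("fire and forget"); the stream orders the work.
pinned = torch.arange(4096, dtype=torch.int32, pin_memory=True)
gpu_copy = pinned.to(device, non_blocking=True)  # no CPU-side stall
```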
That's effectively what this PR does; the CPU buffers in …
Oh my bad! Sorry, I was saying they are passing the device-side buffers.
I am wondering if we can override this plan function so that the wrapper directly owns the device-side persistent buffer from vLLM, avoiding any unnecessary copy (device-to-device or host-to-device). At least for qo_indptr, which is equivalent to query_start_loc, we already have both the CPU and GPU versions of it from common_attn_metadata, so we can reuse them without any further copy.
Is this what you are referring to: https://github.com/sgl-project/sglang/blob/719b29f218a09642193c4bda2a7ffa32829d5604/python/sglang/srt/layers/attention/flashinfer_backend.py#L1229 ? I'm not that familiar with sglang. This is an interesting idea; thanks for sharing! Regardless, even in this overridden version they pass host-side buffers (https://github.com/sgl-project/sglang/blob/719b29f218a09642193c4bda2a7ffa32829d5604/python/sglang/srt/layers/attention/flashinfer_backend.py#L1334-L1336), so if we want to override plan in the future I think we would still want this PR as a stepping stone (and override plan in a follow-up PR).
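To make the idea concrete, here is a purely hypothetical sketch of what such an override could look like if the wrapper were allowed to reuse vLLM-owned persistent device buffers. The class, method signature, and buffer names are invented for illustration; they do not reflect FlashInfer's or sglang's actual API.

```python
# Hypothetical sketch only: the wrapper and buffer names are invented and do not
# match FlashInfer's real plan() signature. The point is just "copy the already-
# available host tensors into persistent device buffers, no new allocations".
import torch


class PersistentPlanWrapper:
    def __init__(self, qo_indptr_gpu: torch.Tensor,
                 paged_kv_indptr_gpu: torch.Tensor):
        # Persistent device buffers owned by vLLM (e.g. for CUDA graph reuse).
        self._qo_indptr_gpu = qo_indptr_gpu
        self._paged_kv_indptr_gpu = paged_kv_indptr_gpu

    def plan(self, qo_indptr_cpu: torch.Tensor,
             paged_kv_indptr_cpu: torch.Tensor) -> None:
        # Host-side planning math would happen here on the CPU tensors; the
        # results are then copied into the persistent buffers with
        # non-blocking H2D transfers instead of allocating fresh tensors.
        n = qo_indptr_cpu.numel()
        self._qo_indptr_gpu[:n].copy_(qo_indptr_cpu, non_blocking=True)
        m = paged_kv_indptr_cpu.numel()
        self._paged_kv_indptr_gpu[:m].copy_(paged_kv_indptr_cpu, non_blocking=True)
```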
Could you make sure to test the trtllm case in the flashinfer backend as well? Just want to make sure this choice is also preferable for that path if it is affected.
Purpose
FlashInfer prefers host-side CPU buffers in many cases, for example: https://github.com/flashinfer-ai/flashinfer/blob/3c40456effae8b9c5b1a11c0d1e0594295b1a312/flashinfer/prefill.py#L1430-L1436
So we pass host-side buffers (since #20466 we now have access to these) to reduce D2H transfers.
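As a rough sketch of the approach (variable and function names are illustrative, not the PR's actual builder code): the planning tensors FlashInfer wants on the host can be derived directly from the CPU buffers that CommonAttentionMetadata already carries, so no extra .cpu() (D2H) call is needed before plan().

```python
# Illustrative sketch, not the PR's builder: derive FlashInfer-style planning
# tensors on the host from CPU buffers we already have, avoiding D2H copies.
import torch


def build_planning_tensors_cpu(query_start_loc_cpu: torch.Tensor,
                               seq_lens_cpu: torch.Tensor,
                               page_size: int):
    # qo_indptr is exactly the per-request query offset array already on CPU.
    qo_indptr_cpu = query_start_loc_cpu

    # KV-cache pages used per request, and the indptr-style prefix sum.
    num_pages = (seq_lens_cpu + page_size - 1) // page_size
    cum_pages = torch.cumsum(num_pages, dim=0)
    paged_kv_indptr_cpu = torch.cat([cum_pages.new_zeros(1), cum_pages])

    # Length of the last (possibly partial) page of each request.
    paged_kv_last_page_len_cpu = seq_lens_cpu - (num_pages - 1) * page_size

    # These host tensors are what get handed to the wrapper's plan(); any H2D
    # copies the library needs can then be issued non-blocking on its side.
    return qo_indptr_cpu, paged_kv_indptr_cpu, paged_kv_last_page_len_cpu


# Example usage with made-up int32 host tensors:
qo, indptr, last_len = build_planning_tensors_cpu(
    torch.tensor([0, 5, 12], dtype=torch.int32),
    torch.tensor([30, 45], dtype=torch.int32),
    page_size=16)
```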
Trace from main showing D2H transfers in plan:
Test Plan
Test Result
Accuracy Results
Benchmark Results
Benchmark Command:
Results (3 runs per condition, mean ± standard error):
Tested on NVIDIA B200 GPU with meta-llama/Llama-3.2-3B-Instruct (256→128 tokens)